{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "ab835cce",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['S大写:1882=柒伍贰捌柒伍贰捌柒伍贰捌柒伍贰捌EP',\n",
       " 'S字母:6947=wuuopuuopuuopuuiiE',\n",
       " 'S数字:1913=7652765276527652EP',\n",
       " 'S字母:3075=qwepqwepqwepqweppE',\n",
       " 'S小写:3408=一三六三三三六三三三六三三三六三二E',\n",
       " 'S数字:6841=27366736673667364E',\n",
       " 'S数字:2796=11185118511851184E',\n",
       " 'S大写:5667=贰贰陆柒零贰陆柒零贰陆柒零贰陆陆捌E',\n",
       " 'S字母:8800=etwpetwpetwpetwppE',\n",
       " 'S大写:6691=贰陆柒陆陆陆柒陆陆陆柒陆陆陆柒陆肆E']"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "%run common.ipynb\n",
    "\n",
    "[tokenizer.decode(i) for i in tokenizer.get_batch_data(prefix=True)[1]][:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "bcb3d7c8",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/root/anaconda3/envs/pt2/lib/python3.10/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n"
     ]
    }
   ],
   "source": [
    "model_ppo = ModelPPO(torch.load('gen.model'))\n",
    "model_ppo_ref = ModelPPO(torch.load('gen.model'))\n",
    "\n",
    "for i in model_ppo_ref.parameters():\n",
    "    i.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "4b5f88e2",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-0.6798, -0.9093, -0.2102,  0.4281, -0.0594,  1.7366, -0.4399,  0.5919,\n",
       "         1.0430, -1.0704, -0.7517, -1.2599,  1.4613, -1.1815,  0.5282])"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_kl(a, b):\n",
    "    method = 'kl'\n",
    "\n",
    "    if method == 'kl':\n",
    "        return a - b\n",
    "\n",
    "    if method == 'abs':\n",
    "        return (a - b).abs()\n",
    "\n",
    "    if method == 'mse':\n",
    "        return (a - b).square() * 0.5\n",
    "\n",
    "    if method == 'full':\n",
    "        return torch.nn.functional.kl_div(a,\n",
    "                                          b,\n",
    "                                          log_target=True,\n",
    "                                          reduction='none')\n",
    "\n",
    "\n",
    "get_kl(torch.randn(15), torch.zeros(15))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "9d600b60",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "<__main__.PPOTrainer at 0x7f7f1fb52fe0>"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from trl.core import clip_by_value, logprobs_from_logits, masked_mean, masked_whiten\n",
    "\n",
    "\n",
    "class PPOTrainer:\n",
    "\n",
    "    def __init__(self):\n",
    "        self.optimizer = torch.optim.Adam(model_ppo.parameters(), lr=1e-5)\n",
    "\n",
    "    def step(self, question, answer, reward):\n",
    "        with torch.no_grad():\n",
    "            #编码\n",
    "            token = [q.tolist() + a.tolist() for q, a in zip(question, answer)]\n",
    "            input_ids, attention_mask = tokenizer.batch_pad(token=token)\n",
    "            del token\n",
    "            input_ids = torch.LongTensor(input_ids).to(device)\n",
    "            attention_mask = torch.LongTensor(attention_mask).to(device)\n",
    "\n",
    "            #question和answer不需要内容,只需要长度信息即可\n",
    "            lens_q = [len(i) for i in question]\n",
    "            lens_a = [len(i) for i in answer]\n",
    "            del question\n",
    "            del answer\n",
    "\n",
    "            #根据question计算answer的概率,并计算每个动作的分数\n",
    "            prob_log, value, mask = self.batched_forward_pass(\n",
    "                model_ppo, input_ids, attention_mask, lens_q, lens_a)\n",
    "\n",
    "            #使用ref模型计算概率,这是为了计算kl散度\n",
    "            prob_log_ref, _, _ = self.batched_forward_pass(\n",
    "                model_ppo_ref, input_ids, attention_mask, lens_q, lens_a)\n",
    "\n",
    "            #计算两份概率的kl散度,并融入reward\n",
    "            reward = self.compute_rewards(reward, prob_log, prob_log_ref, mask)\n",
    "\n",
    "            #计算delta和target,用于计算loss\n",
    "            value, delta, target = self.compute_advantages(value, reward, mask)\n",
    "\n",
    "        #每批数据循环N次模型\n",
    "        for _ in range(4):\n",
    "            #每次算一个数据\n",
    "            for i in range(len(input_ids)):\n",
    "                #重新计算概率和value\n",
    "                prob_log_new, value_new, _ = self.batched_forward_pass(\n",
    "                    model_ppo, input_ids[i].unsqueeze(0),\n",
    "                    attention_mask[i].unsqueeze(0), [lens_q[i]], [lens_a[i]])\n",
    "\n",
    "                #根据新旧概率求出变化率,进而求出loss\n",
    "                #根据target和value的差可以计算出另外一份loss\n",
    "                loss = self.get_loss(prob_log[i].unsqueeze(0),\n",
    "                                     value[i].unsqueeze(0), prob_log_new,\n",
    "                                     value_new, mask[i].unsqueeze(0),\n",
    "                                     delta[i].unsqueeze(0),\n",
    "                                     target[i].unsqueeze(0))\n",
    "\n",
    "                if not loss:\n",
    "                    continue\n",
    "\n",
    "                loss.backward()\n",
    "                #torch.nn.utils.clip_grad_norm_(model_ppo.parameters(), 1.0)\n",
    "                self.optimizer.step()\n",
    "                self.optimizer.zero_grad()\n",
    "\n",
    "    def batched_forward_pass(self, model, input_ids, attention_mask, lens_q,\n",
    "                             lens_a):\n",
    "        logits, value = model(input_ids=input_ids,\n",
    "                              attention_mask=attention_mask)\n",
    "\n",
    "        #取每个字的概率对数\n",
    "        prob_log = logprobs_from_logits(logits[:, :-1], input_ids[:, 1:])\n",
    "\n",
    "        #是预测结果并且不是PAD的位置是1\n",
    "        mask = torch.zeros_like(attention_mask)\n",
    "        mask[:, :-1] = attention_mask[:, 1:]\n",
    "        for i in range(len(input_ids)):\n",
    "            start = lens_q[i] - 1\n",
    "            end = start + lens_a[i]\n",
    "            mask[i, :start] = 0\n",
    "            mask[i, end:] = 0\n",
    "\n",
    "        #对最后一个字的预测没有意义,直接丢弃\n",
    "        value = value[:, :-1]\n",
    "        mask = mask[:, :-1]\n",
    "\n",
    "        return prob_log, value, mask\n",
    "\n",
    "    def compute_rewards(self, reward, prob_log, prob_log_ref, mask):\n",
    "        reward_kl = []\n",
    "\n",
    "        for i in range(len(reward)):\n",
    "            #求两份概率的kl散度\n",
    "            kl = get_kl(prob_log[i], prob_log_ref[i]) * -0.2\n",
    "\n",
    "            #把reward加在最后一个字的kl散度上\n",
    "            if (mask[i] == 0).all():\n",
    "                #print('all 0')\n",
    "                idx = 0\n",
    "            else:\n",
    "                idx = mask[i].nonzero()[-1].item()\n",
    "            kl[idx] += reward[i]\n",
    "\n",
    "            reward_kl.append(kl)\n",
    "\n",
    "        return torch.stack(reward_kl)\n",
    "\n",
    "    def compute_advantages(self, value, reward_kl, mask):\n",
    "        value = value * mask\n",
    "        reward_kl = reward_kl * mask\n",
    "\n",
    "        delta = []\n",
    "        lens = reward_kl.shape[1]\n",
    "\n",
    "        #从后往前遍历\n",
    "        for i in reversed(range(lens)):\n",
    "            #取下一时刻的value,如果已经是最后一个时刻,则value_next是0\n",
    "            #因为整个循环是从后往前,所以第0次是0,其他时刻取value\n",
    "            value_next = 0\n",
    "            if i < lens - 1:\n",
    "                value_next = value[:, i + 1]\n",
    "\n",
    "            #value = gamma*下一时刻的value + reward\n",
    "            #理论上相等,这里的差定义为delta,这里gamma是1,所以省略了\n",
    "            d = reward_kl[:, i] + value_next - value[:, i]\n",
    "\n",
    "            #取最后一个delta,如果还没有,则初始化为0\n",
    "            last_d = 0\n",
    "            if delta:\n",
    "                last_d = delta[-1]\n",
    "\n",
    "            #delta是从后往前传递的,这里的系数衡量了前后动作的因果关联性\n",
    "            delta.append(d + 0.95 * last_d)\n",
    "\n",
    "        #翻转顺序\n",
    "        delta = torch.stack(delta[::-1]).transpose(0, 1)\n",
    "\n",
    "        #定义target,它估计了理想的value值\n",
    "        target = delta + value\n",
    "        delta = masked_whiten(delta, mask)\n",
    "\n",
    "        return value, delta, target\n",
    "\n",
    "    def get_loss(self, prob_log, value, prob_log_new, value_new, mask, delta,\n",
    "                 target):\n",
    "\n",
    "        #对数概率,相除变相减,取exp后还原为商,即两个模型输出logits的变化率\n",
    "        ratio = (prob_log_new - prob_log).exp()\n",
    "\n",
    "        #如果变化率太过于剧烈,可能是发生了震荡,跳过\n",
    "        if masked_mean(ratio, mask).item() > 10:\n",
    "            #print('skip', masked_mean(ratio, mask).item())\n",
    "            return None\n",
    "\n",
    "        #先算两个value的loss,简单的算mse loss就可以了\n",
    "        loss_vf1 = (value_new - target)**2\n",
    "        #数值裁剪,很显然是为了缓解自举\n",
    "        loss_vf2 = clip_by_value(value_new, value - 0.2, value + 0.2)\n",
    "        loss_vf2 = (loss_vf2 - target)**2\n",
    "        #两份loss取大的,还是为了缓解自举\n",
    "        loss_vf = 0.5 * masked_mean(torch.max(loss_vf1, loss_vf2), mask)\n",
    "\n",
    "        #计算ppo loss\n",
    "        loss_surr1 = -delta * ratio\n",
    "        #数值裁剪,很显然是为了缓解自举\n",
    "        loss_surr2 = -delta * ratio.clamp(0.8, 1.2)\n",
    "        loss_surr = masked_mean(torch.max(loss_surr1, loss_surr2), mask)\n",
    "\n",
    "        return loss_surr + 0.1 * loss_vf\n",
    "\n",
    "\n",
    "trainer = PPOTrainer()\n",
    "\n",
    "trainer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "d99a7626",
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "model_cls = torch.load('cls.model')\n",
    "model_cls.to(device)\n",
    "\n",
    "for i in model_cls.parameters():\n",
    "    i.requires_grad_(False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "c3e7a327",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(tensor([0, 1, 2, 2, 2, 3, 3, 0, 3, 0, 1, 0, 1, 2, 2, 1, 3, 3, 1, 3, 3, 3, 0, 1,\n",
       "         0, 1, 1, 0, 3, 0, 1, 2, 3, 0, 1, 2, 0, 1, 0, 2, 3, 2, 2, 1, 3, 0, 3, 3,\n",
       "         3, 2, 1, 3, 2, 3, 1, 1, 0, 3, 2, 3, 0, 0, 1, 0], device='cuda:0'),\n",
       " [tensor([ 1, 44, 45, 50,  6, 10, 11,  8, 51], device='cuda:0'),\n",
       "  tensor([ 1, 45, 49, 50,  9,  6,  4, 11, 51], device='cuda:0'),\n",
       "  tensor([ 1, 48, 47, 50,  9,  7,  8,  8, 51], device='cuda:0'),\n",
       "  tensor([ 1, 48, 47, 50,  9,  4,  8,  6, 51], device='cuda:0'),\n",
       "  tensor([ 1, 48, 47, 50, 11,  4,  7, 10, 51], device='cuda:0'),\n",
       "  tensor([ 1, 46, 47, 50,  6, 11,  8,  6, 51], device='cuda:0'),\n",
       "  tensor([ 1, 46, 47, 50, 13,  7, 11,  8, 51], device='cuda:0'),\n",
       "  tensor([ 1, 44, 45, 50,  6,  5,  8, 10, 51], device='cuda:0'),\n",
       "  tensor([ 1, 46, 47, 50,  6,  4,  6,  4, 51], device='cuda:0'),\n",
       "  tensor([ 1, 44, 45, 50, 12,  8, 12,  8, 51], device='cuda:0')])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_question():\n",
    "    label, question, _ = tokenizer.get_batch_data(prefix=True)\n",
    "    label = torch.LongTensor(label).to(device)\n",
    "\n",
    "    #只要问题部分,等号后面的内容切除\n",
    "    question = [i[:i.index(tokenizer.encoder['=']) + 1] for i in question]\n",
    "    question = [torch.LongTensor(i).to(device) for i in question]\n",
    "\n",
    "    return label, question\n",
    "\n",
    "\n",
    "label, question = get_question()\n",
    "\n",
    "label, question[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "b77e7d96",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor([ 5,  4, 11,  5, 11,  4, 11,  5, 11,  4, 11,  5, 11,  4, 11,  5, 10,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([ 6,  4, 12,  7,  4,  4, 12,  7,  4,  4, 12,  7,  4,  4, 12,  6, 12,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([26, 25, 27, 31, 32, 25, 27, 31, 32, 25, 27, 31, 32, 25, 27, 31, 30,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([26, 24, 25, 31, 24, 24, 25, 31, 24, 24, 25, 31, 24, 24, 25, 30, 32,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([16, 22, 15, 18, 20, 22, 15, 18, 20, 22, 15, 18, 20, 22, 15, 18, 18,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([35, 34, 43, 40, 43, 34, 43, 40, 43, 34, 43, 40, 43, 34, 43, 40, 42,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([27, 31, 28, 33, 33, 31, 28, 33, 33, 31, 28, 33, 33, 31, 28, 33, 30,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([32, 29, 31, 28, 32, 29, 31, 28, 32, 29, 31, 28, 32, 29, 31, 28,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([12,  4, 12,  4, 12,  4, 12,  4, 12,  4, 12,  4, 12,  4, 12,  4,  2],\n",
       "        device='cuda:0'),\n",
       " tensor([ 7,  7, 13,  7, 13,  7, 13,  7, 13,  7, 13,  7, 13,  7, 13,  7, 10,  2],\n",
       "        device='cuda:0')]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#如果question的长度确定,这里可以转换成批运算\n",
    "def get_answer(question):\n",
    "    answer = [generate(model_ppo.model_gen, i.unsqueeze(0)) for i in question]\n",
    "\n",
    "    #裁剪,只要生成的部分\n",
    "    answer = [a[0, len(q):] for q, a in zip(question, answer)]\n",
    "\n",
    "    return answer\n",
    "\n",
    "\n",
    "answer = get_answer(question)\n",
    "\n",
    "answer[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "720b4c0f",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([ 3.6872, -1.1460,  3.8057,  3.9364, -1.4702,  3.7345, -0.4405, -0.6043,\n",
       "        -0.8919,  3.5087, -0.9511, -1.1840, -1.4771, -1.8334,  4.2692,  3.8740,\n",
       "        -0.8011, -0.5006, -1.1768,  3.8684, -0.8430, -1.2840, -1.3416, -1.3015,\n",
       "         3.8101, -1.0993, -1.4701,  3.8751, -0.8155, -0.6414, -0.9203, -0.4777,\n",
       "         3.4747, -1.2405,  3.7453,  4.3441, -1.3180, -0.9811, -1.2982, -0.5892,\n",
       "        -0.9671,  3.6851, -0.4399, -1.7880, -0.9857, -0.9828, -0.7380, -1.0874,\n",
       "        -1.2675, -1.4632,  3.6310,  3.3023, -1.1148,  3.3922, -1.2878,  3.8767,\n",
       "        -1.4827, -0.4335, -1.2032,  3.4332,  3.5496, -1.2879, -1.3275, -1.1620],\n",
       "       device='cuda:0')"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "def get_reward(question, answer, label):\n",
    "    token = [q.tolist() + a.tolist() for q, a in zip(question, answer)]\n",
    "\n",
    "    input_ids, attention_mask = tokenizer.batch_pad(token=token)\n",
    "    input_ids = torch.LongTensor(input_ids).to(device)\n",
    "    attention_mask = torch.LongTensor(attention_mask).to(device)\n",
    "\n",
    "    with torch.no_grad():\n",
    "        logits = model_cls(input_ids=input_ids, attention_mask=attention_mask)\n",
    "\n",
    "    return logits.gather(1, label.reshape(-1, 1)).squeeze(1)\n",
    "\n",
    "\n",
    "reward = get_reward(question, answer, label)\n",
    "\n",
    "reward"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "a1c70e8e",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0 0.2149033397436142\n",
      "S字母:8964= etitotitotitotityE 3.619676351547241\n",
      "S小写:3435= qeurqeurqeurqeurpE -1.2818214893341064\n",
      "100 0.5845248699188232\n",
      "S字母:2815= 一一二六一一二六一一二六一一二六〇E -1.7823433876037598\n",
      "S小写:4382= 一七五二九七五二九七五二九七五二八E 4.079746246337891\n",
      "200 0.43839702010154724\n",
      "S小写:8092= 32371237123712368E -0.2243329882621765\n",
      "S大写:1042= 4068406840684068E -1.1130495071411133\n",
      "300 1.0763766765594482\n",
      "S数字:4186= 17745774577457744E 3.7182233333587646\n",
      "S字母:3719= 一四八七七四八七七四八七七四八七六E -1.4818665981292725\n",
      "400 1.4338525533676147\n",
      "S小写:3025= 一二一〇一二一〇一二一〇一二一〇〇E 4.044591903686523\n",
      "S字母:7551= 30207020702070204E -1.1103068590164185\n",
      "500 2.2621121406555176\n",
      "S字母:9061= 36247624762476244E -1.077667236328125\n",
      "S大写:1859= ureyureyureyureyE -1.058005928993225\n",
      "600 3.02119517326355\n",
      "S字母:7934= equeoqueoqueoqueyE 4.11110782623291\n",
      "S字母:1110= rrrprrrprrrprrrpE 3.9544923305511475\n",
      "700 3.7655293941497803\n",
      "S大写:5480= 贰壹玖贰贰壹玖贰贰壹玖贰贰壹玖贰零E 3.579922914505005\n",
      "S字母:5507= wwpepwpepwpepwpwiE 3.8950295448303223\n",
      "800 3.623467445373535\n",
      "S小写:4569= 一八二七七八二七七八二七七八二七六E 4.201150894165039\n",
      "S数字:2844= 11377137713771376E 3.7510595321655273\n",
      "900 3.7655224800109863\n",
      "S大写:7291= 贰玖壹陆陆玖壹陆陆玖壹陆陆玖壹陆肆E 3.3911094665527344\n",
      "S小写:5137= 二〇五五八〇五五八〇五五八〇五四八E 4.173751354217529\n",
      "1000 3.627218723297119\n",
      "S小写:3215= 一二八六一二八六一二八六一二八六〇E 3.9007530212402344\n",
      "S小写:8705= 三四八二三四八二三四八二三四八二〇E 3.5744049549102783\n",
      "1100 3.737043857574463\n",
      "S字母:9835= eoereoereoereoerpE 3.967334508895874\n",
      "S小写:7249= 二八九九八八九九八八九九八八九九六E 4.226598739624023\n",
      "1200 3.78059720993042\n",
      "S大写:4338= 壹柒叁伍叁柒叁伍叁柒叁伍叁柒叁伍贰E 3.652658462524414\n",
      "S小写:8322= 三三二九一三二九一三二九一三二八八E 4.050459861755371\n",
      "1300 3.6444921493530273\n",
      "S字母:9624= oirooirooirooiroyE 3.794785261154175\n",
      "S小写:6775= 二七一〇二七一〇二七一〇二七一〇〇E 3.90024471282959\n",
      "1400 3.6444501876831055\n",
      "S字母:1078= reqwreqwreqwreqwE 3.854517698287964\n",
      "S字母:3319= qewuuewuuewuuewuyE 3.735929489135742\n",
      "1500 3.79252552986145\n",
      "S大写:9311= 叁柒贰肆柒柒贰肆柒柒贰肆柒柒贰肆肆E 3.667511224746704\n",
      "S数字:3334= 13337333733373336E 3.7022907733917236\n",
      "1600 3.8058130741119385\n",
      "S字母:2006= ipwripwripwripwrE 3.838862419128418\n",
      "S小写:8268= 三三〇七五三〇七五三〇七五三〇七二E 4.06358003616333\n",
      "1700 3.683478832244873\n",
      "S大写:3870= 壹伍肆捌壹伍肆捌壹伍肆捌壹伍肆捌零E 3.5110630989074707\n",
      "S字母:6452= wtiqptiqptiqptipiE 3.606231451034546\n",
      "1800 3.716475248336792\n",
      "S小写:3754= 一五〇一七五〇一七五〇一七五〇一六E 3.910107135772705\n",
      "S大写:8529= 叁肆壹壹玖肆壹壹玖肆壹壹玖肆壹壹陆E 3.8759965896606445\n",
      "1900 3.7528605461120605\n",
      "S字母:6046= wrqiyrqiyrqiyrqirE 4.003557205200195\n",
      "S数字:7947= 31771177117711768E 3.7246463298797607\n"
     ]
    }
   ],
   "source": [
    "for epoch in range(2000):\n",
    "    label, question = get_question()\n",
    "    answer = get_answer(question)\n",
    "    reward = get_reward(question, answer, label)\n",
    "\n",
    "    trainer.step(question, answer, reward)\n",
    "\n",
    "    if epoch % 100 == 0:\n",
    "        print(epoch, reward.mean().item())\n",
    "        for _, q, a, r in zip(range(2), question, answer, reward):\n",
    "            q = tokenizer.decode(q.tolist())\n",
    "            a = tokenizer.decode(a.tolist())\n",
    "            r = r.item()\n",
    "            print(q, a, r)\n",
    "\n",
    "model_ppo.to('cpu')\n",
    "torch.save(model_ppo, 'ppo.model')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pt2]",
   "language": "python",
   "name": "conda-env-pt2-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
